--- ../../swvox/swvox/helpers.py	2024-02-14 05:16:57.541407536 -0500
+++ ../google/swvox/swvox/helpers.py	2024-02-13 10:46:24.576645015 -0500
@@ -54,7 +54,10 @@
                 leaf_key = torch.tensor([leaf_key], device=tree.data.device)
                 self.single_key = True
             leaf_node = self.tree._all_leaves()
-            leaf_node = leaf_node.__getitem__(leaf_key).T
+            if type(leaf_key) is slice:
+                leaf_node = leaf_node.__getitem__(leaf_key).T
+            else:
+                leaf_node = leaf_node.__getitem__(leaf_key.to(leaf_node.device)).T
         if isinstance(key, tuple):
             self.key = (*leaf_node, *key[1:])
         else:
@@ -89,12 +92,12 @@
         else:
             self.tree.data.data[self.key] = value
 
-    def refine(self, repeats=1):
+    def refine(self, repeats=1, copy_data_from_parent=False, return_data_positions=False):
         """
         Refine selected leaves using tree.refine
         """
         self._check_ver()
-        ret = self.tree.refine(repeats, sel=self._unique_node_key())
+        ret = self.tree.refine(repeats, sel=self._unique_node_key(), copy_data_from_parent=copy_data_from_parent, return_data_positions=return_data_positions)
         return ret
 
     @property
@@ -361,16 +364,19 @@
 def _get_c_extension():
     from warnings import warn
     try:
-        import svox.csrc as _C
+        import swvox.csrc as _C
         if not hasattr(_C, "query_vertical"):
+            exception = "swvox.csrc does not have query_vertical"
             _C = None
-    except:
+    except Exception as e:
+        exception = str(e)
         _C = None
 
     if _C is None:
-        warn("CUDA extension svox.csrc could not be loaded! " +
+        warn("CUDA extension swvox.csrc could not be loaded! " +
              "Operations will be slow.\n" +
-             "Please do not import svox in the SVOX source directory.")
+             "Please do not import swvox in the SWVOX source directory.")
+        print("Failed with exception: " + exception)
     return _C
 
 class LocalIndex:
@@ -381,6 +387,49 @@
     def __init__(self, val):
         self.val = val
 
+class WaveletType:
+    """
+    Wavelet basis descriptor selected by name.
+
+    Instances expose:
+      wavelet_name: the name passed to the constructor
+      wavelet_type: integer code (one of the class constants below)
+      n_wavelets: integer count paired with that type
+
+    :param txt: one of 'haar', 'constant', 'trilinear', 'side', 'db2', 'bior22'
+    :raises AssertionError: if txt is not one of the names above
+    """
+    CONSTANT=0
+    HAAR=1
+    TRILINEAR=2
+    SIDE=3
+    DB2=4
+    BIOR22=5
+
+    # name -> (wavelet_type, n_wavelets); single lookup table for __init__
+    _SPECS = {
+        'constant': (CONSTANT, 1),
+        'haar': (HAAR, 7),
+        'trilinear': (TRILINEAR, 8),
+        'side': (SIDE, 2),
+        'db2': (DB2, 8),
+        'bior22': (BIOR22, 8),
+    }
+
+    def __init__(self, txt):
+        # Non-str input (None, list, ...) is rejected like any unknown name.
+        spec = WaveletType._SPECS.get(txt) if isinstance(txt, str) else None
+        if spec is None:
+            # Raised explicitly instead of via `assert` so the check still
+            # runs under `python -O`; AssertionError is kept so callers that
+            # relied on the old assert see the same exception type.
+            raise AssertionError("Unknown wavelet type: " + str(txt))
+        self.wavelet_name = txt
+        self.wavelet_type, self.n_wavelets = spec
+
+    def __repr__(self):
+        return 'WaveletType({})'.format(self.wavelet_name)
+
 class DataFormat:
     RGBA = 0
     SH = 1
@@ -393,7 +442,6 @@
             self.basis_dim = int(txt[nonalph_idx:])
             format_type = txt[:nonalph_idx]
             assert self.basis_dim > 0, "data_format basis_dim must be positive"
-            self.data_dim = 3 * self.basis_dim + 1
             if format_type == "SH":
                 self.format = DataFormat.SH
                 assert int(self.basis_dim ** 0.5) ** 2 == self.basis_dim, \
@@ -407,8 +455,7 @@
                 self.format = DataFormat.RGBA
         else:
             self.format = DataFormat.RGBA
-            self.basis_dim = -1
-            self.data_dim = None
+            self.basis_dim = 1
 
     def __repr__(self):
         if self.format == DataFormat.SH:
